"""Training and evaluation utilities for MK-CAViT (classification & segmentation)."""
import math
import random
import time
from typing import Dict, Optional, Tuple, Iterable

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch random generators and set cuDNN flags.

    Args:
        seed: Value passed to every seeding call.

    Besides seeding, this sets ``torch.backends.cudnn.deterministic`` to
    ``True`` and ``torch.backends.cudnn.benchmark`` to ``False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def accuracy(logits: torch.Tensor, target: torch.Tensor, topk: Tuple[int, ...] = (1,)) -> Dict[str, float]:
    """Compute top-k accuracy (in percent) for a batch of class predictions.

    Args:
        logits: Score tensor; the top-k search runs along dimension 1.
        target: 1-D tensor of class indices, one per row of ``logits``.
        topk: The values of ``k`` to report.

    Returns:
        A dict mapping ``'top{k}'`` to the percentage of samples whose target
        is among the ``k`` highest-scoring classes. Every entry is ``0.0``
        when ``target`` is not 1-D or when the batch is empty.
    """
    zeros = {f'top{k}': 0.0 for k in topk}
    if target.ndim != 1:
        return zeros
    batch_size = target.size(0)
    # Guard the degenerate cases: an empty batch would divide by zero and an
    # empty ``topk`` would make ``max()`` raise.
    if batch_size == 0 or not topk:
        return zeros

    # ``Tensor.topk`` raises if k exceeds the number of classes, so clamp it.
    # Slicing ``correct[:k]`` below then simply covers every class for large k.
    maxk = min(max(topk), logits.size(1))
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)  # (B, maxk)
    pred = pred.t()                                                # (maxk, B)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    out = {}
    for k in topk:
        hits = correct[:k].reshape(-1).float().sum().item()
        # Percentage is computed in Python floats (double precision).
        out[f'top{k}'] = hits * 100.0 / batch_size
    return out


def train_one_epoch(
    model: torch.nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    task: str = 'cls',     # 'cls' or 'multilabel'
    mu: float = 0.1,
    scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> Dict[str, float]:
    """Run one optimisation pass over ``loader``.

    The model is called as ``model(imgs, labels=target, mu=mu)`` and must
    return a 2-tuple whose second element is the loss to back-propagate.

    Args:
        model: Network to train; it is put into training mode.
        loader: Iterable yielding ``(imgs, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        task: ``'multilabel'`` casts targets to float; anything else casts
            them to long.
        mu: Forwarded to the model unchanged.
        scaler: When given, autocast is enabled and the scaler drives the
            backward pass and the optimizer step.

    Returns:
        ``{'loss': sample-weighted mean loss, 'time': elapsed seconds}``.
    """
    model.train()
    total_loss, total_samples = 0.0, 0
    use_amp = scaler is not None
    tic = time.time()
    for imgs, target in loader:
        imgs = imgs.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        target = target.float() if task == 'multilabel' else target.long()

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            _, loss = model(imgs, labels=target, mu=mu)

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        batch_size = imgs.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

    # Normalise by the samples actually processed rather than
    # len(loader.dataset): the two differ with drop_last or a subset sampler,
    # and not every loader exposes a sized dataset.
    return {'loss': total_loss / max(total_samples, 1), 'time': time.time() - tic}


@torch.no_grad()
def evaluate_cls(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    task: str = 'cls'
) -> Dict[str, float]:
    """Evaluate a classifier over ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; called as ``model(imgs)`` and expected to
            return logits. It is put into eval mode.
        loader: Iterable yielding ``(imgs, target)`` batches.
        device: Device the batch tensors are moved to.
        task: ``'multilabel'`` uses float targets with binary cross-entropy;
            any other value uses long targets with cross-entropy. Top-1
            accuracy is reported only when ``task == 'cls'``.

    Returns:
        ``{'loss': ...}`` averaged over the samples seen, plus ``'top1'``
        (percent) for the ``'cls'`` task.
    """
    model.eval()
    is_multilabel = task == 'multilabel'
    report_top1 = task == 'cls'
    criterion = F.binary_cross_entropy_with_logits if is_multilabel else F.cross_entropy

    loss_sum = 0.0
    seen = 0
    hits = 0.0
    for imgs, target in loader:
        imgs = imgs.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        target = target.float() if is_multilabel else target.long()

        logits = model(imgs)
        loss = criterion(logits, target)

        n = imgs.size(0)
        loss_sum += loss.item() * n
        seen += n

        if report_top1:
            # accuracy() reports a percentage; convert it back to a hit count.
            hits += accuracy(logits, target, (1,))['top1'] * n / 100.0

    denom = max(seen, 1)
    out = {'loss': loss_sum / denom}
    if report_top1:
        out['top1'] = 100.0 * hits / denom
    return out


@torch.no_grad()
def evaluate_seg(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    num_classes: int = 150,
    ignore_index: int = 255
) -> Dict[str, float]:
    """Evaluate a segmentation model over ``loader`` without tracking gradients.

    Args:
        model: Network exposing ``forward_seg(imgs, out_size=...)``; it is
            put into eval mode.
        loader: Iterable yielding ``(imgs, labels)`` batches.
        device: Device the batch tensors are moved to.
        num_classes: Currently unused; kept for interface compatibility.
        ignore_index: Label value excluded from both loss and accuracy.

    Returns:
        ``{'loss': ..., 'pixAcc': ...}`` where ``loss`` is the mean of the
        per-batch cross-entropy weighted by batch size, and ``pixAcc`` is the
        percentage of non-ignored pixels predicted correctly.
    """
    model.eval()
    total_loss, total_samples = 0.0, 0
    total_correct, total_pixels = 0, 0
    for imgs, labels in loader:
        imgs = imgs.to(device, non_blocking=True)
        # cross_entropy needs int64 class indices; masks often load as uint8.
        labels = labels.to(device, non_blocking=True).long()
        logits = model.forward_seg(imgs, out_size=labels.shape[-2:])

        mask = labels != ignore_index
        valid_pixels = mask.sum().item()
        if valid_pixels == 0:
            # With every pixel ignored the mean-reduced loss is NaN and would
            # poison the running total, so the batch is left out entirely.
            continue

        loss = F.cross_entropy(logits, labels, ignore_index=ignore_index)
        batch_size = imgs.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

        pred = logits.argmax(dim=1)
        total_correct += (pred[mask] == labels[mask]).sum().item()
        total_pixels += valid_pixels

    # Normalise by the samples that contributed to the loss rather than
    # len(loader.dataset), which is wrong with drop_last or subset samplers.
    return {'loss': total_loss / max(total_samples, 1),
            'pixAcc': 100.0 * total_correct / max(total_pixels, 1)}


# --- COCO multi-label helpers ---

def coco_multilabel_collate(
    batch: Iterable,
    num_classes: int = 80,
    cat_id_to_index: Optional[Dict[int, int]] = None,
):
    """
    Collate function for CocoDetection -> (image, multi-hot vector).
    Each sample is (img, targets) where targets is a list of dicts with 'category_id'.

    Args:
        batch: Iterable of ``(img, annotations)`` pairs.
        num_classes: Width of the multi-hot vector.
        cat_id_to_index: Optional mapping from ``category_id`` to a column
            index. When omitted, ``category_id`` itself is used as the column.

    Returns:
        ``(images, Y)`` where ``images`` is the stacked image tensor and ``Y``
        is a float32 tensor with one row per sample and ``num_classes`` columns.

    Ids that are unmapped or fall outside ``[0, num_classes)`` are skipped.
    NOTE(review): raw COCO category ids are non-contiguous and run up to 90,
    so with the defaults ids >= 80 are dropped. Pass ``cat_id_to_index`` (e.g.
    via ``functools.partial``) when feeding unmapped annotations.
    """
    imgs, label_lists = [], []
    for img, anns in batch:
        imgs.append(img)
        label_lists.append([ann['category_id'] for ann in anns])

    # Build the multi-hot target matrix.
    Y = torch.zeros(len(imgs), num_classes, dtype=torch.float32)
    for row, cat_ids in enumerate(label_lists):
        for cat_id in cat_ids:
            if cat_id_to_index is None:
                col = cat_id
            else:
                # -1 fails the range check below, so unmapped ids are skipped.
                col = cat_id_to_index.get(cat_id, -1)
            if 0 <= col < num_classes:
                Y[row, col] = 1.0
    return torch.stack(imgs), Y
